{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tutorial 9: Recurrent GNNs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this tutorial we will implement an approximation of the Graph Neural Network Model (without enforcing the contraction map) and analyze the `GatedGraphConv` layer of PyTorch Geometric."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.11.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/antonio/anaconda3/envs/geometric_new/lib/python3.9/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import torch\n",
    "os.environ['TORCH'] = torch.__version__\n",
    "print(torch.__version__)\n",
    "\n",
    "!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html\n",
    "!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html\n",
    "!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os.path as osp\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch_geometric.transforms as T\n",
    "import torch_geometric\n",
    "from torch_geometric.datasets import Planetoid, TUDataset\n",
    "from torch_geometric.data import DataLoader\n",
    "from torch_geometric.nn.inits import uniform\n",
    "from torch.nn import Parameter as Param\n",
    "from torch import Tensor \n",
    "torch.manual_seed(42)\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "device = \"cpu\"\n",
    "from torch_geometric.nn.conv import MessagePassing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the Planetoid 'Cora' dataset under data/Cora, applying the\n",
    "# RandomNodeSplit and TargetIndegree transforms to each accessed graph.\n",
    "# NOTE(review): the following cell reassigns both `dataset` and `data`, so the\n",
    "# objects created here are replaced before the models below read them.\n",
    "dataset = 'Cora'\n",
    "transform = T.Compose([\n",
    "    T.RandomNodeSplit('train_rest', num_val=500, num_test=500),\n",
    "    T.TargetIndegree(),\n",
    "])\n",
    "path = osp.join('data', dataset)\n",
    "dataset = Planetoid(path, dataset, transform=transform)\n",
    "data = dataset[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load 'Cora' with the NormalizeFeatures transform and keep its single graph\n",
    "# on `device`; `dataset` first holds the name, then the dataset object.\n",
    "dataset = 'Cora'\n",
    "path = osp.join('data', dataset)\n",
    "dataset = Planetoid(root=path, name=dataset, transform=T.NormalizeFeatures())\n",
    "data = dataset[0].to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Graph Neural Network Model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"./transition.png\" width=\"500\">"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"./output.png\" width=\"500\">"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The MLP class is used to instantiate the transition and output functions as simple feed-forward networks."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MLP(nn.Module):\n",
    "    \"\"\"Feed-forward network: Linear layers with a Tanh between consecutive layers.\n",
    "\n",
    "    Args:\n",
    "        input_dim: Number of input features.\n",
    "        hid_dims: Sizes of the hidden layers (list or tuple; may be empty).\n",
    "        out_dim: Number of output features.\n",
    "\n",
    "    No activation follows the last Linear layer.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_dim, hid_dims, out_dim):\n",
    "        super(MLP, self).__init__()\n",
    "\n",
    "        # Unpacking accepts any sequence, so a tuple of hidden sizes works too.\n",
    "        dims = [input_dim, *hid_dims, out_dim]\n",
    "        self.mlp = nn.Sequential()\n",
    "        for i, (d_in, d_out) in enumerate(zip(dims[:-1], dims[1:])):\n",
    "            self.mlp.add_module('lay_{}'.format(i), nn.Linear(in_features=d_in, out_features=d_out))\n",
    "            # Every layer except the output layer is followed by Tanh.\n",
    "            if i + 2 < len(dims):\n",
    "                self.mlp.add_module('act_{}'.format(i), nn.Tanh())\n",
    "\n",
    "    def reset_parameters(self):\n",
    "        \"\"\"Re-initialize every Linear weight with Xavier-normal values.\n",
    "\n",
    "        Biases are left as they are.\n",
    "        \"\"\"\n",
    "        for layer in self.mlp:\n",
    "            if isinstance(layer, nn.Linear):\n",
    "                nn.init.xavier_normal_(layer.weight)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Apply the stacked layers to ``x`` and return the result.\"\"\"\n",
    "        return self.mlp(x)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The GNNM class puts together the state propagation and the readout of the nodes' states."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "class GNNM(MessagePassing):\n",
    "    \"\"\"Approximate Graph Neural Network Model (no contraction-map constraint).\n",
    "\n",
    "    Node states start at zero. Each iteration aggregates neighbor states and\n",
    "    feeds the result through the ``transition`` MLP; iteration stops once every\n",
    "    node state changes by less than ``eps`` or after ``num_layers`` iterations.\n",
    "    The ``readout`` MLP then maps the final states to class scores.\n",
    "\n",
    "    Args:\n",
    "        n_nodes: Number of nodes (rows of the state matrix).\n",
    "        out_channels: Number of output classes.\n",
    "        features_dim: Size of each node state.\n",
    "        hid_dims: Hidden layer sizes used by both MLPs.\n",
    "        num_layers: Maximum number of propagation iterations.\n",
    "        eps: Convergence threshold on the per-node norm of the state change.\n",
    "        aggr: Aggregation scheme forwarded to ``MessagePassing``.\n",
    "        bias: Accepted but not used by this class.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, n_nodes, out_channels, features_dim, hid_dims, num_layers = 50, eps=1e-3, aggr = 'add',\n",
    "                 bias = True, **kwargs):\n",
    "        super(GNNM, self).__init__(aggr=aggr, **kwargs)\n",
    "\n",
    "        # Fixed (non-trainable) initial state shared by every forward pass.\n",
    "        self.node_states = Param(torch.zeros((n_nodes, features_dim)), requires_grad=False)\n",
    "        self.out_channels = out_channels\n",
    "        self.eps = eps\n",
    "        self.num_layers = num_layers\n",
    "\n",
    "        self.transition = MLP(features_dim, hid_dims, features_dim)\n",
    "        self.readout = MLP(features_dim, hid_dims, out_channels)\n",
    "\n",
    "        self.reset_parameters()\n",
    "        print(self.transition)\n",
    "        print(self.readout)\n",
    "\n",
    "    def reset_parameters(self):\n",
    "        \"\"\"Re-initialize the transition and readout networks.\"\"\"\n",
    "        self.transition.reset_parameters()\n",
    "        self.readout.reset_parameters()\n",
    "\n",
    "    def forward(self, edge_index=None, edge_weight=None):\n",
    "        \"\"\"Iterate the node states and return per-node log-softmax scores.\n",
    "\n",
    "        Args:\n",
    "            edge_index: Graph connectivity. When omitted, connectivity and edge\n",
    "                weights are both read from the notebook-level ``data`` object.\n",
    "            edge_weight: Optional per-edge weights used to scale messages.\n",
    "        \"\"\"\n",
    "        if edge_index is None:\n",
    "            edge_index = data.edge_index\n",
    "            edge_weight = data.edge_attr\n",
    "\n",
    "        node_states = self.node_states\n",
    "        for _ in range(self.num_layers):\n",
    "            m = self.propagate(edge_index, x=node_states, edge_weight=edge_weight,\n",
    "                               size=None)\n",
    "            new_states = self.transition(m)\n",
    "            # The convergence test is bookkeeping only; keep it out of autograd.\n",
    "            with torch.no_grad():\n",
    "                distance = torch.norm(new_states - node_states, dim=1)\n",
    "                convergence = distance < self.eps\n",
    "            node_states = new_states\n",
    "            if convergence.all():\n",
    "                break\n",
    "\n",
    "        out = self.readout(node_states)\n",
    "\n",
    "        return F.log_softmax(out, dim=-1)\n",
    "\n",
    "    def message(self, x_j, edge_weight):\n",
    "        \"\"\"Return the neighbor state, scaled by its edge weight when one is given.\"\"\"\n",
    "        return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j\n",
    "\n",
    "    def message_and_aggregate(self, adj_t, x) :\n",
    "        \"\"\"Fused message + aggregation as one sparse-dense product.\"\"\"\n",
    "        # Local import: no sparse matmul helper is imported at notebook level.\n",
    "        from torch_geometric.utils import spmm\n",
    "        return spmm(adj_t, x, reduce=self.aggr)\n",
    "\n",
    "    def __repr__(self):\n",
    "        return '{}({}, num_layers={})'.format(self.__class__.__name__,\n",
    "                                              self.out_channels,\n",
    "                                              self.num_layers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MLP(\n",
      "  (mlp): Sequential(\n",
      "    (lay_0): Linear(in_features=32, out_features=64, bias=True)\n",
      "    (act_0): Tanh()\n",
      "    (lay_1): Linear(in_features=64, out_features=64, bias=True)\n",
      "    (act_1): Tanh()\n",
      "    (lay_2): Linear(in_features=64, out_features=64, bias=True)\n",
      "    (act_2): Tanh()\n",
      "    (lay_3): Linear(in_features=64, out_features=64, bias=True)\n",
      "    (act_3): Tanh()\n",
      "    (lay_4): Linear(in_features=64, out_features=64, bias=True)\n",
      "    (act_4): Tanh()\n",
      "    (lay_5): Linear(in_features=64, out_features=32, bias=True)\n",
      "  )\n",
      ")\n",
      "MLP(\n",
      "  (mlp): Sequential(\n",
      "    (lay_0): Linear(in_features=32, out_features=64, bias=True)\n",
      "    (act_0): Tanh()\n",
      "    (lay_1): Linear(in_features=64, out_features=64, bias=True)\n",
      "    (act_1): Tanh()\n",
      "    (lay_2): Linear(in_features=64, out_features=64, bias=True)\n",
      "    (act_2): Tanh()\n",
      "    (lay_3): Linear(in_features=64, out_features=64, bias=True)\n",
      "    (act_3): Tanh()\n",
      "    (lay_4): Linear(in_features=64, out_features=64, bias=True)\n",
      "    (act_4): Tanh()\n",
      "    (lay_5): Linear(in_features=64, out_features=7, bias=True)\n",
      "  )\n",
      ")\n",
      "Epoch: 001, Train Acc: 0.12857, Val Acc: 0.06800, Test Acc: 0.08800\n",
      "Epoch: 002, Train Acc: 0.14286, Val Acc: 0.25200, Test Acc: 0.25100\n",
      "Epoch: 003, Train Acc: 0.12143, Val Acc: 0.24400, Test Acc: 0.26100\n",
      "Epoch: 004, Train Acc: 0.17143, Val Acc: 0.20200, Test Acc: 0.20200\n",
      "Epoch: 005, Train Acc: 0.16429, Val Acc: 0.23000, Test Acc: 0.23100\n",
      "Epoch: 006, Train Acc: 0.22857, Val Acc: 0.10000, Test Acc: 0.10500\n",
      "Epoch: 007, Train Acc: 0.14286, Val Acc: 0.11400, Test Acc: 0.10000\n",
      "Epoch: 008, Train Acc: 0.14286, Val Acc: 0.08400, Test Acc: 0.08000\n",
      "Epoch: 009, Train Acc: 0.14286, Val Acc: 0.06800, Test Acc: 0.06500\n",
      "Epoch: 010, Train Acc: 0.18571, Val Acc: 0.18600, Test Acc: 0.17000\n",
      "Epoch: 011, Train Acc: 0.08571, Val Acc: 0.07800, Test Acc: 0.07700\n",
      "Epoch: 012, Train Acc: 0.14286, Val Acc: 0.08000, Test Acc: 0.08500\n",
      "Epoch: 013, Train Acc: 0.11429, Val Acc: 0.06800, Test Acc: 0.07000\n",
      "Epoch: 014, Train Acc: 0.15714, Val Acc: 0.10200, Test Acc: 0.08600\n",
      "Epoch: 015, Train Acc: 0.16429, Val Acc: 0.20000, Test Acc: 0.17500\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "Input \u001b[0;32mIn [16]\u001b[0m, in \u001b[0;36m<cell line: 28>\u001b[0;34m()\u001b[0m\n\u001b[1;32m     28\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m epoch \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m51\u001b[39m):\n\u001b[1;32m     29\u001b[0m     train()\n\u001b[0;32m---> 30\u001b[0m     accs \u001b[38;5;241m=\u001b[39m \u001b[43mtest\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     31\u001b[0m     train_acc \u001b[38;5;241m=\u001b[39m accs[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m     32\u001b[0m     val_acc \u001b[38;5;241m=\u001b[39m accs[\u001b[38;5;241m1\u001b[39m]\n",
      "Input \u001b[0;32mIn [16]\u001b[0m, in \u001b[0;36mtest\u001b[0;34m()\u001b[0m\n\u001b[1;32m     18\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mtest\u001b[39m():\n\u001b[1;32m     19\u001b[0m     model\u001b[38;5;241m.\u001b[39meval()\n\u001b[0;32m---> 20\u001b[0m     logits, accs \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m, []\n\u001b[1;32m     21\u001b[0m     \u001b[38;5;28;01mfor\u001b[39;00m _, mask \u001b[38;5;129;01min\u001b[39;00m data(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtrain_mask\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mval_mask\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtest_mask\u001b[39m\u001b[38;5;124m'\u001b[39m):\n\u001b[1;32m     22\u001b[0m         pred \u001b[38;5;241m=\u001b[39m logits[mask]\u001b[38;5;241m.\u001b[39mmax(\u001b[38;5;241m1\u001b[39m)[\u001b[38;5;241m1\u001b[39m]\n",
      "File \u001b[0;32m~/anaconda3/envs/geometric_new/lib/python3.9/site-packages/torch/nn/modules/module.py:1110\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   1106\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1107\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1108\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1109\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1110\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1111\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m   1112\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
      "Input \u001b[0;32mIn [15]\u001b[0m, in \u001b[0;36mGNNM.forward\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m     26\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m i \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnum_layers):\n\u001b[1;32m     27\u001b[0m     m \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpropagate(edge_index, x\u001b[38;5;241m=\u001b[39mnode_states, edge_weight\u001b[38;5;241m=\u001b[39medge_weight,\n\u001b[1;32m     28\u001b[0m                        size\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m)\n\u001b[0;32m---> 29\u001b[0m     new_states \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtransition\u001b[49m\u001b[43m(\u001b[49m\u001b[43mm\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     30\u001b[0m     \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n\u001b[1;32m     31\u001b[0m         distance \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mnorm(new_states \u001b[38;5;241m-\u001b[39m node_states, dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m)\n",
      "File \u001b[0;32m~/anaconda3/envs/geometric_new/lib/python3.9/site-packages/torch/nn/modules/module.py:1110\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   1106\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1107\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1108\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1109\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1110\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1111\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m   1112\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
      "Input \u001b[0;32mIn [14]\u001b[0m, in \u001b[0;36mMLP.forward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m     16\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, x):\n\u001b[0;32m---> 17\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmlp\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/anaconda3/envs/geometric_new/lib/python3.9/site-packages/torch/nn/modules/module.py:1110\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   1106\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1107\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1108\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1109\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1110\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1111\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m   1112\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
      "File \u001b[0;32m~/anaconda3/envs/geometric_new/lib/python3.9/site-packages/torch/nn/modules/container.py:141\u001b[0m, in \u001b[0;36mSequential.forward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m    139\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m):\n\u001b[1;32m    140\u001b[0m     \u001b[38;5;28;01mfor\u001b[39;00m module \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m:\n\u001b[0;32m--> 141\u001b[0m         \u001b[38;5;28minput\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[43mmodule\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m    142\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28minput\u001b[39m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "model = GNNM(data.num_nodes, dataset.num_classes, 32, [64,64,64,64,64], eps=0.01).to(device)\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n",
    "# The model already returns log-probabilities, so NLLLoss is the matching\n",
    "# criterion (CrossEntropyLoss would apply log-softmax a second time).\n",
    "loss_fn = nn.NLLLoss()\n",
    "\n",
    "\n",
    "# NOTE(review): these splits and loaders are not used by train()/test() in this cell.\n",
    "test_dataset = dataset[:len(dataset) // 10]\n",
    "train_dataset = dataset[len(dataset) // 10:]\n",
    "test_loader = DataLoader(test_dataset)\n",
    "train_loader = DataLoader(train_dataset)\n",
    "\n",
    "def train():\n",
    "    \"\"\"Run one full-batch optimization step on the training nodes.\"\"\"\n",
    "    model.train()\n",
    "    optimizer.zero_grad()\n",
    "    out = model()\n",
    "    loss = loss_fn(out[data.train_mask], data.y[data.train_mask])\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "\n",
    "@torch.no_grad()  # evaluation needs no autograd graph\n",
    "def test():\n",
    "    \"\"\"Return the accuracies on the train, validation and test masks, in that order.\"\"\"\n",
    "    model.eval()\n",
    "    logits, accs = model(), []\n",
    "    for _, mask in data('train_mask', 'val_mask', 'test_mask'):\n",
    "        pred = logits[mask].argmax(dim=1)\n",
    "        acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item()\n",
    "        accs.append(acc)\n",
    "    return accs\n",
    "\n",
    "\n",
    "for epoch in range(1, 51):\n",
    "    train()\n",
    "    accs = test()\n",
    "    train_acc, val_acc, test_acc = accs\n",
    "    print('Epoch: {:03d}, Train Acc: {:.5f}, '\n",
    "          'Val Acc: {:.5f}, Test Acc: {:.5f}'.format(epoch, train_acc,\n",
    "                                                       val_acc, test_acc))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Gated Graph Neural Network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "class GatedGraphConv(MessagePassing):\n",
    "    \"\"\"Gated graph convolution: linear map, neighbor aggregation, GRU update.\n",
    "\n",
    "    Args:\n",
    "        out_channels: Size of each node state.\n",
    "        num_layers: Number of propagation steps; each step has its own weight matrix.\n",
    "        aggr: Aggregation scheme forwarded to ``MessagePassing``.\n",
    "        bias: Whether the GRU cell uses bias terms.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, out_channels, num_layers, aggr = 'add',\n",
    "                 bias = True, **kwargs):\n",
    "        super(GatedGraphConv, self).__init__(aggr=aggr, **kwargs)\n",
    "\n",
    "        self.out_channels = out_channels\n",
    "        self.num_layers = num_layers\n",
    "\n",
    "        # Allocated uninitialized; reset_parameters() fills it in below.\n",
    "        self.weight = Param(torch.empty(num_layers, out_channels, out_channels))\n",
    "        self.rnn = torch.nn.GRUCell(out_channels, out_channels, bias=bias)\n",
    "\n",
    "        self.reset_parameters()\n",
    "\n",
    "    def reset_parameters(self):\n",
    "        \"\"\"Re-initialize the per-step weights and the GRU cell.\"\"\"\n",
    "        uniform(self.out_channels, self.weight)\n",
    "        self.rnn.reset_parameters()\n",
    "\n",
    "    def forward(self, data):\n",
    "        \"\"\"Run ``num_layers`` gated propagation steps on ``data``.\n",
    "\n",
    "        Reads ``data.x``, ``data.edge_index`` and ``data.edge_attr``. Inputs with\n",
    "        fewer than ``out_channels`` features are zero-padded on the right.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: If ``data.x`` has more than ``out_channels`` features.\n",
    "        \"\"\"\n",
    "        x = data.x\n",
    "        edge_index = data.edge_index\n",
    "        edge_weight = data.edge_attr\n",
    "        if x.size(-1) > self.out_channels:\n",
    "            raise ValueError('The number of input channels is not allowed to '\n",
    "                             'be larger than the number of output channels')\n",
    "\n",
    "        if x.size(-1) < self.out_channels:\n",
    "            zero = x.new_zeros(x.size(0), self.out_channels - x.size(-1))\n",
    "            x = torch.cat([x, zero], dim=1)\n",
    "\n",
    "        for i in range(self.num_layers):\n",
    "            m = torch.matmul(x, self.weight[i])\n",
    "            m = self.propagate(edge_index, x=m, edge_weight=edge_weight,\n",
    "                               size=None)\n",
    "            # Aggregated messages are the GRU input; the current state is its hidden state.\n",
    "            x = self.rnn(m, x)\n",
    "\n",
    "        return x\n",
    "\n",
    "    def message(self, x_j, edge_weight):\n",
    "        \"\"\"Return the neighbor state, scaled by its edge weight when one is given.\"\"\"\n",
    "        return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j\n",
    "\n",
    "    def message_and_aggregate(self, adj_t, x):\n",
    "        \"\"\"Fused message + aggregation as one sparse-dense product.\"\"\"\n",
    "        # Local import: no sparse matmul helper is imported at notebook level.\n",
    "        from torch_geometric.utils import spmm\n",
    "        return spmm(adj_t, x, reduce=self.aggr)\n",
    "\n",
    "    def __repr__(self):\n",
    "        return '{}({}, num_layers={})'.format(self.__class__.__name__,\n",
    "                                              self.out_channels,\n",
    "                                              self.num_layers)\n",
    "\n",
    "class GGNN(torch.nn.Module):\n",
    "    \"\"\"GatedGraphConv followed by an MLP classifier, run on the notebook-level ``data``.\n",
    "\n",
    "    Args:\n",
    "        channels: State size of the gated convolution and input size of the MLP.\n",
    "            NOTE(review): the default 1433 presumably equals the feature width\n",
    "            of the loaded dataset - confirm if the dataset changes.\n",
    "        num_layers: Number of gated propagation steps.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, channels=1433, num_layers=3):\n",
    "        super(GGNN, self).__init__()\n",
    "\n",
    "        self.conv = GatedGraphConv(channels, num_layers)\n",
    "        self.mlp = MLP(channels, [32,32,32], dataset.num_classes)\n",
    "\n",
    "    def forward(self):\n",
    "        \"\"\"Return per-node log-softmax class scores for ``data``.\"\"\"\n",
    "        x = self.conv(data)\n",
    "        x = self.mlp(x)\n",
    "        return F.log_softmax(x, dim=-1)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 001, Train Acc: 0.27143, Val Acc: 0.15800, Test Acc: 0.15400\n",
      "Epoch: 002, Train Acc: 0.35000, Val Acc: 0.22200, Test Acc: 0.22200\n",
      "Epoch: 003, Train Acc: 0.18571, Val Acc: 0.22400, Test Acc: 0.21000\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "Input \u001b[0;32mIn [18]\u001b[0m, in \u001b[0;36m<cell line: 29>\u001b[0;34m()\u001b[0m\n\u001b[1;32m     26\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m accs\n\u001b[1;32m     29\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m epoch \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m51\u001b[39m):\n\u001b[0;32m---> 30\u001b[0m     \u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     31\u001b[0m     accs \u001b[38;5;241m=\u001b[39m test()\n\u001b[1;32m     32\u001b[0m     train_acc \u001b[38;5;241m=\u001b[39m accs[\u001b[38;5;241m0\u001b[39m]\n",
      "Input \u001b[0;32mIn [18]\u001b[0m, in \u001b[0;36mtrain\u001b[0;34m()\u001b[0m\n\u001b[1;32m     13\u001b[0m model\u001b[38;5;241m.\u001b[39mtrain()\n\u001b[1;32m     14\u001b[0m optimizer\u001b[38;5;241m.\u001b[39mzero_grad()\n\u001b[0;32m---> 15\u001b[0m loss_fn(\u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m[data\u001b[38;5;241m.\u001b[39mtrain_mask], data\u001b[38;5;241m.\u001b[39my[data\u001b[38;5;241m.\u001b[39mtrain_mask])\u001b[38;5;241m.\u001b[39mbackward()\n\u001b[1;32m     16\u001b[0m optimizer\u001b[38;5;241m.\u001b[39mstep()\n",
      "File \u001b[0;32m~/anaconda3/envs/geometric_new/lib/python3.9/site-packages/torch/nn/modules/module.py:1110\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   1106\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1107\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1108\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1109\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1110\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1111\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m   1112\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
      "Input \u001b[0;32mIn [17]\u001b[0m, in \u001b[0;36mGGNN.forward\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m     58\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m):\n\u001b[0;32m---> 59\u001b[0m     x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mconv\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     60\u001b[0m     x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmlp(x)\n\u001b[1;32m     61\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m F\u001b[38;5;241m.\u001b[39mlog_softmax(x, dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m)\n",
      "File \u001b[0;32m~/anaconda3/envs/geometric_new/lib/python3.9/site-packages/torch/nn/modules/module.py:1110\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   1106\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1107\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1108\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1109\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1110\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1111\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m   1112\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
      "Input \u001b[0;32mIn [17]\u001b[0m, in \u001b[0;36mGatedGraphConv.forward\u001b[0;34m(self, data)\u001b[0m\n\u001b[1;32m     33\u001b[0m     m \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mmatmul(x, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mweight[i])\n\u001b[1;32m     34\u001b[0m     m \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpropagate(edge_index, x\u001b[38;5;241m=\u001b[39mm, edge_weight\u001b[38;5;241m=\u001b[39medge_weight,\n\u001b[1;32m     35\u001b[0m                        size\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m)\n\u001b[0;32m---> 36\u001b[0m     x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrnn\u001b[49m\u001b[43m(\u001b[49m\u001b[43mm\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     38\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m x\n",
      "File \u001b[0;32m~/anaconda3/envs/geometric_new/lib/python3.9/site-packages/torch/nn/modules/module.py:1110\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   1106\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m   1107\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m   1108\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m   1109\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1110\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1111\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m   1112\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
      "File \u001b[0;32m~/anaconda3/envs/geometric_new/lib/python3.9/site-packages/torch/nn/modules/rnn.py:1267\u001b[0m, in \u001b[0;36mGRUCell.forward\u001b[0;34m(self, input, hx)\u001b[0m\n\u001b[1;32m   1264\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m   1265\u001b[0m     hx \u001b[38;5;241m=\u001b[39m hx\u001b[38;5;241m.\u001b[39munsqueeze(\u001b[38;5;241m0\u001b[39m) \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m is_batched \u001b[38;5;28;01melse\u001b[39;00m hx\n\u001b[0;32m-> 1267\u001b[0m ret \u001b[38;5;241m=\u001b[39m \u001b[43m_VF\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgru_cell\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m   1268\u001b[0m \u001b[43m    \u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mhx\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1269\u001b[0m \u001b[43m    \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mweight_ih\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mweight_hh\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1270\u001b[0m \u001b[43m    \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbias_ih\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbias_hh\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1271\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1273\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m is_batched:\n\u001b[1;32m   1274\u001b[0m     ret \u001b[38;5;241m=\u001b[39m ret\u001b[38;5;241m.\u001b[39msqueeze(\u001b[38;5;241m0\u001b[39m)\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "device = \"cpu\"\n",
    "model = GGNN().to(device)\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n",
    "loss_fn = nn.CrossEntropyLoss()\n",
    "\n",
    "\n",
    "test_dataset = dataset[:len(dataset) // 10]\n",
    "train_dataset = dataset[len(dataset) // 10:]\n",
    "test_loader = DataLoader(test_dataset)\n",
    "train_loader = DataLoader(train_dataset)\n",
    "\n",
    "def train():\n",
    "    model.train()\n",
    "    optimizer.zero_grad()\n",
    "    loss_fn(model()[data.train_mask], data.y[data.train_mask]).backward()\n",
    "    optimizer.step()\n",
    "\n",
    "\n",
    "def test():\n",
    "    model.eval()\n",
    "    logits, accs = model(), []\n",
    "    for _, mask in data('train_mask', 'val_mask', 'test_mask'):\n",
    "        pred = logits[mask].max(1)[1]\n",
    "        acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item()\n",
    "        accs.append(acc)\n",
    "    return accs\n",
    "\n",
    "\n",
    "for epoch in range(1, 51):\n",
    "    train()\n",
    "    accs = test()\n",
    "    train_acc = accs[0]\n",
    "    val_acc = accs[1]\n",
    "    test_acc = accs[2]\n",
    "    print('Epoch: {:03d}, Train Acc: {:.5f}, '\n",
    "          'Val Acc: {:.5f}, Test Acc: {:.5f}'.format(epoch, train_acc,\n",
    "                                                       val_acc, test_acc))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
